GAN Loss Improvements: LSGAN, WGAN, WGAN-GP

Introduction: following the previous post on the classic GAN loss and its implementation, this post takes a closer look at several improved GAN losses.

Drawbacks of the classic GAN loss:

[Figure: two generated distributions at different distances from the real data distribution]

In the classic GAN loss the discriminator is a binary classifier. For the two distributions shown in the figure, the binary discriminator gives the same loss in both cases, even though the loss for the closer distribution should actually be smaller.

Why is the loss the same in both cases?

Explanation: when the discriminator is strong enough, its final-layer output (after softmax) gives the probability of each sample belonging to class 1 or class 2, and these probabilities saturate at 1 (or 0) and 0 (or 1). The resulting loss is therefore log 2 every time, which matches the fact that the JS divergence between two non-overlapping distributions is always log 2, no matter how far apart they are.
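A minimal sketch of this point (the toy 1-D distributions below are illustrative assumptions, not from the original post): the JS divergence between two non-overlapping distributions evaluates to log 2 regardless of how far apart they are.

import numpy as np

def js_divergence(p, q):
    # Jensen-Shannon divergence (natural log) between two discrete distributions
    m = 0.5 * (p + q)
    def kl(a, b):
        mask = a > 0
        return np.sum(a[mask] * np.log(a[mask] / b[mask]))
    return 0.5 * kl(p, m) + 0.5 * kl(q, m)

# A 1-D support of 10 positions; the data distribution sits at position 0
p_data = np.zeros(10); p_data[0] = 1.0

# Two generated distributions that do not overlap with the data:
# one close to it (position 1) and one far away (position 9)
p_g_near = np.zeros(10); p_g_near[1] = 1.0
p_g_far = np.zeros(10); p_g_far[9] = 1.0

print(js_divergence(p_data, p_g_near))  # log 2 ~= 0.693
print(js_divergence(p_data, p_g_far))   # log 2 ~= 0.693, unchanged although the distance is larger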

LSGAN:

Simply switch the discriminator from a classification loss to a regression loss, as in the snippet below.

...
# !!! Minimizes MSE instead of BCE
adversarial_loss = torch.nn.MSELoss()
...
g_loss = adversarial_loss(discriminator(gen_imgs), valid)
...
real_loss = adversarial_loss(discriminator(real_imgs), valid)
fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
d_loss = 0.5 * (real_loss + fake_loss)
...
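The targets `valid` and `fake` above are plain regression targets rather than class labels. A minimal sketch of how they could be constructed (the shapes and the single-output discriminator are assumptions, not part of the original snippet):

# Regression targets for the MSE loss above: 1.0 for real images, 0.0 for generated ones.
# The shape (batch_size, 1) assumes a discriminator with a single linear output unit.
valid = torch.ones(real_imgs.size(0), 1, device=real_imgs.device)
fake = torch.zeros(real_imgs.size(0), 1, device=real_imgs.device)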

WGAN:

[Figure: the WGAN distribution distance (Earth Mover's distance) illustrated with moving plans]

Each heat map corresponds to a moving plan that transports the generated distribution P_G onto the data distribution P_data. The actual WGAN distance is obtained by solving an optimization problem: over all moving plans, find the optimal one that minimizes the average moving distance B; this minimum is the Earth Mover's (Wasserstein) distance.
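For contrast with the JS-divergence sketch above, here is a minimal example (same toy distributions, using scipy.stats.wasserstein_distance for 1-D distributions) showing that the Earth Mover's distance does grow with how far apart the distributions are:

import numpy as np
from scipy.stats import wasserstein_distance

positions = np.arange(10, dtype=float)

# Same non-overlapping toy distributions: data at position 0,
# one generator close to it (position 1), one far from it (position 9)
p_data = np.zeros(10); p_data[0] = 1.0
p_g_near = np.zeros(10); p_g_near[1] = 1.0
p_g_far = np.zeros(10); p_g_far[9] = 1.0

print(wasserstein_distance(positions, positions, p_data, p_g_near))  # 1.0
print(wasserstein_distance(positions, positions, p_data, p_g_far))   # 9.0 -- grows with the distance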

Professor Hung-yi Lee does not derive the formula in the lecture, and this post will not go into the derivation either; its form is given directly:

W(P_data, P_G) = max_{D is 1-Lipschitz} { E_{x~P_data}[D(x)] - E_{x~P_G}[D(x)] }

The maximization is taken over discriminators D that are 1-Lipschitz; this smoothness constraint is what weight clipping (below) tries to enforce.

Weight clipping

To keep the discriminator (approximately) Lipschitz, each of its parameters must change smoothly and cannot become too large or too small: after every update, parameters outside a fixed range are clipped back into it.

...
# Critic loss: maximize D(real) - D(fake), i.e. minimize the negation
loss_D = -torch.mean(discriminator(real_imgs)) + torch.mean(discriminator(fake_imgs))
loss_D.backward()
optimizer_D.step()

# Weight clipping: clamp every parameter into [-clip_value, clip_value]
for p in discriminator.parameters():
    p.data.clamp_(-opt.clip_value, opt.clip_value)

# Train the generator every n_critic iterations
if i % opt.n_critic == 0:
    ...
    # Generator loss: maximize D(G(z)), i.e. minimize the negation
    loss_G = -torch.mean(discriminator(gen_imgs))
    loss_G.backward()
    optimizer_G.step()

WGAN-GP:

Weight clipping is a crude way to enforce the Lipschitz constraint. WGAN-GP instead adds a gradient penalty: sample points on straight lines between real and fake samples and penalize the deviation of the gradient norm of D at those points from 1.


def calc_gradient_penalty(netD, real_data, fake_data, LAMBDA, device):
    # Random interpolation coefficient, broadcast to the shape of the data
    alpha = torch.rand(1, 1)
    alpha = alpha.expand(real_data.size()).to(device)

    # Points sampled on straight lines between real and fake samples;
    # detach so they become leaf tensors we can differentiate D with respect to
    interpolates = (alpha * real_data + (1 - alpha) * fake_data).detach().to(device)
    interpolates.requires_grad_(True)

    disc_interpolates = netD(interpolates)

    # Gradient of D at the interpolated points
    gradients = torch.autograd.grad(outputs=disc_interpolates, inputs=interpolates,
                                    grad_outputs=torch.ones(disc_interpolates.size()).to(device),
                                    create_graph=True, retain_graph=True, only_inputs=True)[0]

    # Penalize the deviation of the gradient norm from 1
    gradient_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean() * LAMBDA
    return gradient_penalty



### netD ####
netD.zero_grad()
output = netD(real).to(opt.device)
errD_real = -output.mean()           # maximize D(real)
errD_real.backward(retain_graph=True)

fake = netG(noise)
output = netD(fake.detach())
errD_fake = output.mean()            # minimize D(fake)
errD_fake.backward(retain_graph=True)

gradient_penalty = functions.calc_gradient_penalty(netD, real, fake, opt.lambda_grad, opt.device)
gradient_penalty.backward()

errD = errD_real + errD_fake + gradient_penalty  # total critic loss, kept only for logging/plotting
optimizerD.step()



#### netG ####
netG.zero_grad()
output = netD(fake)
errG = -output.mean()                # generator tries to maximize D(fake)
errG.backward(retain_graph=True)
optimizerG.step()